{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import os\n",
    "\n",
    "import pathlib\n",
    "import time\n",
    "from tqdm import tqdm\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "tqdm.pandas()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "root = './data/cleaning_data_95.csv'\n",
    "data = pd.read_csv(root)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>item</th>\n",
       "      <th>labels</th>\n",
       "      <th>subject</th>\n",
       "      <th>topic</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>加 试题 阅读 材料 完成 下列 问题 材料 一辽 中南 工业区 左图 北部湾 经济区 右图...</td>\n",
       "      <td>高中 地理 宇宙中的地球 工业区位因素 地球运动的地理意义</td>\n",
       "      <td>地理</td>\n",
       "      <td>宇宙中的地球</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>精子 变形 叙述 错误 线粒体 退化 消失 高尔基体 发育 头部 顶体 中心 体 演变 精子...</td>\n",
       "      <td>高中 生物 现代生物技术专题 生物技术在其他方面的应用</td>\n",
       "      <td>生物</td>\n",
       "      <td>现代生物技术专题</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>哈尔滨 北京 上海 广州 四地 自转 速度 关系 自转 线速度 相同 自转角 速度 相同 线...</td>\n",
       "      <td>高中 地理 宇宙中的地球 地球运动的基本形式</td>\n",
       "      <td>地理</td>\n",
       "      <td>宇宙中的地球</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>太阳活动 高峰期 时 色球层 上 太阳黑子 增多 光球 层上 耀斑 大 爆发 黑子 活动 耀...</td>\n",
       "      <td>高中 地理 宇宙中的地球 太阳对地球的影响</td>\n",
       "      <td>地理</td>\n",
       "      <td>宇宙中的地球</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>年月日 国家 主席 习近平 伦敦 唐宁街 首相府 英国首相 卡梅伦 举行会谈 双方 决定 共...</td>\n",
       "      <td>高中 历史 现代史 海峡两岸关系的发展</td>\n",
       "      <td>历史</td>\n",
       "      <td>现代史</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                item  \\\n",
       "0  加 试题 阅读 材料 完成 下列 问题 材料 一辽 中南 工业区 左图 北部湾 经济区 右图...   \n",
       "1  精子 变形 叙述 错误 线粒体 退化 消失 高尔基体 发育 头部 顶体 中心 体 演变 精子...   \n",
       "2  哈尔滨 北京 上海 广州 四地 自转 速度 关系 自转 线速度 相同 自转角 速度 相同 线...   \n",
       "3  太阳活动 高峰期 时 色球层 上 太阳黑子 增多 光球 层上 耀斑 大 爆发 黑子 活动 耀...   \n",
       "4  年月日 国家 主席 习近平 伦敦 唐宁街 首相府 英国首相 卡梅伦 举行会谈 双方 决定 共...   \n",
       "\n",
       "                          labels subject     topic  \n",
       "0  高中 地理 宇宙中的地球 工业区位因素 地球运动的地理意义      地理    宇宙中的地球  \n",
       "1    高中 生物 现代生物技术专题 生物技术在其他方面的应用      生物  现代生物技术专题  \n",
       "2         高中 地理 宇宙中的地球 地球运动的基本形式      地理    宇宙中的地球  \n",
       "3          高中 地理 宇宙中的地球 太阳对地球的影响      地理    宇宙中的地球  \n",
       "4            高中 历史 现代史 海峡两岸关系的发展      历史       现代史  "
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "词汇表为：\n",
      " 20000\n"
     ]
    }
   ],
   "source": [
    "corpus=data['item']\n",
    "\n",
    "from sklearn.feature_extraction.text import TfidfVectorizer\n",
    "#使用tfidf统计词频特征 来作为文本特征\n",
    "tfidf_vect = TfidfVectorizer(max_features=20000,min_df=3,ngram_range=(2,4))\n",
    "vec = tfidf_vect.fit_transform(corpus)\n",
    "\n",
    "print('词汇表为：\\n', len(tfidf_vect.vocabulary_))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import MultiLabelBinarizer,LabelEncoder\n",
    "mlb = MultiLabelBinarizer()\n",
    "lb = LabelEncoder()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4分类"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_4 = lb.fit_transform(data['subject'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "x_train,x_test,y_train,y_test = train_test_split(vec,y_4,test_size = 0.2,random_state = 2020)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.naive_bayes import GaussianNB,MultinomialNB,ComplementNB,BernoulliNB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "gnb = GaussianNB()        #高斯分布贝叶斯    假设x的先验概率服从高斯分布\n",
    "clf = MultinomialNB()     #多项式分布贝叶斯  假设x的先验概率服从多项式分布\n",
    "bnb = BernoulliNB()       #伯努利分布贝叶斯  假设x的先验概率服从二元伯努利分布\n",
    "cnb = ComplementNB()      #补充朴素贝叶斯    Mnb的改编，适用于不平衡数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Wall time: 12.7 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "gnb.fit(x_train.toarray(),y_train)  #高斯的输入需转换为np数组\n",
    "\n",
    "clf.fit(x_train,y_train)\n",
    "\n",
    "bnb.fit(x_train,y_train)\n",
    "\n",
    "cnb.fit(x_train,y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "##预测\n",
    "pre1 = gnb.predict(x_test.toarray())\n",
    "pre2 = clf.predict(x_test)\n",
    "pre3 = bnb.predict(x_test)\n",
    "pre4 = cnb.predict(x_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gnb Result：               precision    recall  f1-score   support\n",
      "\n",
      "           0       0.91      0.94      0.92       532\n",
      "           1       0.98      0.94      0.96       978\n",
      "           2       0.93      0.98      0.95       371\n",
      "           3       0.99      0.99      0.99      2635\n",
      "\n",
      "    accuracy                           0.97      4516\n",
      "   macro avg       0.95      0.96      0.96      4516\n",
      "weighted avg       0.97      0.97      0.97      4516\n",
      "\n",
      "clf Result：               precision    recall  f1-score   support\n",
      "\n",
      "           0       0.00      0.00      0.00       532\n",
      "           1       0.94      0.05      0.09       978\n",
      "           2       0.97      0.18      0.30       371\n",
      "           3       0.60      1.00      0.75      2635\n",
      "\n",
      "    accuracy                           0.61      4516\n",
      "   macro avg       0.63      0.31      0.28      4516\n",
      "weighted avg       0.63      0.61      0.48      4516\n",
      "\n",
      "bnb Result：               precision    recall  f1-score   support\n",
      "\n",
      "           0       0.55      0.99      0.71       532\n",
      "           1       0.95      0.94      0.94       978\n",
      "           2       0.97      0.73      0.83       371\n",
      "           3       1.00      0.88      0.94      2635\n",
      "\n",
      "    accuracy                           0.89      4516\n",
      "   macro avg       0.87      0.88      0.86      4516\n",
      "weighted avg       0.93      0.89      0.90      4516\n",
      "\n",
      "cnb Result：               precision    recall  f1-score   support\n",
      "\n",
      "           0       0.90      0.18      0.30       532\n",
      "           1       0.99      0.56      0.72       978\n",
      "           2       0.96      0.67      0.79       371\n",
      "           3       0.73      1.00      0.84      2635\n",
      "\n",
      "    accuracy                           0.78      4516\n",
      "   macro avg       0.90      0.60      0.66      4516\n",
      "weighted avg       0.83      0.78      0.75      4516\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import classification_report\n",
    "print('gnb Result：',classification_report(y_test,pre1))\n",
    "print('clf Result：',classification_report(y_test,pre2))\n",
    "print('bnb Result：',classification_report(y_test,pre3))\n",
    "print('cnb Result：',classification_report(y_test,pre4))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 对比一下其他分类器"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**LR**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Result：               precision    recall  f1-score   support\n",
      "\n",
      "           0       0.79      0.18      0.29       532\n",
      "           1       0.99      0.38      0.55       978\n",
      "           2       0.97      0.33      0.49       371\n",
      "           3       0.68      1.00      0.81      2635\n",
      "\n",
      "    accuracy                           0.71      4516\n",
      "   macro avg       0.86      0.47      0.53      4516\n",
      "weighted avg       0.78      0.71      0.66      4516\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from sklearn.linear_model import LogisticRegression\n",
    "lr = LogisticRegression(C=2)\n",
    "lr.fit(x_train,y_train)\n",
    "pre = lr.predict(x_test)\n",
    "print('Result：',classification_report(y_test,pre))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**SVM**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Result：               precision    recall  f1-score   support\n",
      "\n",
      "           0       0.82      0.18      0.29       532\n",
      "           1       0.99      0.34      0.50       978\n",
      "           2       0.97      0.43      0.60       371\n",
      "           3       0.67      1.00      0.80      2635\n",
      "\n",
      "    accuracy                           0.71      4516\n",
      "   macro avg       0.86      0.49      0.55      4516\n",
      "weighted avg       0.78      0.71      0.66      4516\n",
      "\n",
      "Wall time: 34.3 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "from sklearn.svm import SVC\n",
    "svc = SVC(kernel='rbf',C=1)\n",
    "svc.fit(x_train,y_train)\n",
    "pre = svc.predict(x_test)\n",
    "print('Result：',classification_report(y_test,pre))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**RandomForest**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Result：               precision    recall  f1-score   support\n",
      "\n",
      "           0       0.82      0.88      0.85       532\n",
      "           1       0.94      0.86      0.90       978\n",
      "           2       0.97      0.81      0.88       371\n",
      "           3       0.94      0.98      0.96      2635\n",
      "\n",
      "    accuracy                           0.93      4516\n",
      "   macro avg       0.92      0.88      0.90      4516\n",
      "weighted avg       0.93      0.93      0.93      4516\n",
      "\n",
      "Wall time: 10.7 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "rf = RandomForestClassifier()\n",
    "rf.fit(x_train,y_train)\n",
    "pre = rf.predict(x_test)\n",
    "print('Result：',classification_report(y_test,pre))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**XGboost**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Result：               precision    recall  f1-score   support\n",
      "\n",
      "           0       0.77      0.67      0.72       532\n",
      "           1       0.91      0.52      0.66       978\n",
      "           2       0.97      0.50      0.66       371\n",
      "           3       0.78      0.98      0.87      2635\n",
      "\n",
      "    accuracy                           0.80      4516\n",
      "   macro avg       0.86      0.67      0.73      4516\n",
      "weighted avg       0.82      0.80      0.79      4516\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import xgboost as xgb\n",
    "model = xgb.XGBClassifier()\n",
    "model.fit(x_train,y_train)\n",
    "pre = model.predict(x_test)\n",
    "print('Result：',classification_report(y_test,pre))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 对比几种分类器，在4分类的情况下，GaussianNB的分类效果最好，其次是RandomForest"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
